import java.util.*;

/*
public class ListNode {
    int val;
    ListNode next = null;

    ListNode(int val) {
        this.val = val;
    }
}*/
public class Partition {
    /**
     * Stably partitions a singly linked list around the pivot value {@code x}.
     *
     * <p>Every node whose {@code val} is strictly less than {@code x} is placed
     * before every node whose {@code val} is greater than or equal to {@code x}.
     * Within each group the nodes keep the order in which they appeared in the
     * input. The existing nodes are relinked in place; only two temporary
     * sentinel nodes are allocated.
     *
     * @param pHead first node of the list to partition; may be {@code null}
     * @param x     pivot value that separates the two groups
     * @return the first node of the rearranged list, or {@code null} when
     *         {@code pHead} is {@code null}
     */
    public ListNode partition(ListNode pHead, int x) {
        // Sentinels give each sub-list a fixed anchor, so appending never
        // needs a special case for "first node of the group".
        ListNode lowSentinel = new ListNode(-1);
        ListNode highSentinel = new ListNode(-2);
        ListNode lowTail = lowSentinel;
        ListNode highTail = highSentinel;

        for (ListNode node = pHead; node != null; node = node.next) {
            if (node.val >= x) {
                highTail.next = node;
                highTail = node;
            } else {
                lowTail.next = node;
                lowTail = node;
            }
        }

        // The last high node may still point at a node that was moved into the
        // low group; cut that link so the result cannot contain a cycle.
        highTail.next = null;
        // Splice the high group directly after the low group.
        lowTail.next = highSentinel.next;
        return lowSentinel.next;
    }
}